# -*- coding: utf-8 -*- 
# @Time : 2022/4/4 10:54 
# @Author : zzuxyj 
# @File : 13-nn-modelPretrianedAndModify.py
"""
模型参数预加载参数 和  修改现有的网络模型结构
"""
import torch
import torchvision

# Model pretraining: build two VGG16 instances, both with randomly initialised
# weights. Passing pretrained=True would make torchvision fetch the weights
# itself; here they are instead loaded manually from a local file below.
vgg16_true = torchvision.models.vgg16(pretrained=False)
vgg16_false = torchvision.models.vgg16(pretrained=False)
print(vgg16_true)
print(vgg16_false)

# Load manually downloaded pretrained parameters into vgg16_true.
# map_location="cpu": a checkpoint saved from GPU tensors would otherwise fail
# to load on a CPU-only machine. load_state_dict copies the values into the
# model's own parameters (the model was built on CPU above and never moved),
# so the resulting model is the same either way.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
PRETRAINED_PATH = "./model/vgg16-pretrain.pth"
modelParam = torch.load(PRETRAINED_PATH, map_location="cpu")
vgg16_true.load_state_dict(modelParam)

# Modify an existing model (1): append one extra layer.
# The classifier currently ends with Linear(4096 -> 1000); a new
# Linear(1000 -> 10) is registered after it as submodule "add_liner",
# mapping those 1000 outputs down to 10.
extra_layer = torch.nn.Linear(in_features=1000, out_features=10)
vgg16_true.classifier.add_module("add_liner", extra_layer)
print(vgg16_true)

# Modify an existing model (2): replace one layer in place.
# classifier[6] is the final Linear(4096 -> 1000); swap it for a
# Linear(4096 -> 10) so the network produces 10 outputs instead of 1000.
new_head = torch.nn.Linear(in_features=4096, out_features=10)
vgg16_false.classifier[6] = new_head
print(vgg16_false)



"""
#输出2
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    ......
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    ......
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
#输出3
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    ......
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    ......
    (6): Linear(in_features=4096, out_features=1000, bias=True)
    (add_liner): Linear(in_features=1000, out_features=10, bias=True)
  )
)
#输出4
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    ......
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    ......
    (6): Linear(in_features=4096, out_features=10, bias=True)
  )
)
"""